library(uwot)
context("perplexity")


# Reference values for the full-neighbor case (k = 10 on the 10-point iris10
# data) were obtained by comparison with smallvis results.

# 10-nearest-neighbor graph of iris10 built from its full distance matrix;
# reused below by the calc_row_probabilities_parallel() checks.
iris10_nn10 <- local({
  iris10_dist <- dist(iris10)
  dist_nn(iris10_dist, k = 10)
})




# Expected symmetrized affinities for iris10 at perplexity = 4 when all 10
# neighbors are used. The reference values come from smallvis; the `* 10`
# rescales them (presumably because smallvis normalizes over the whole
# matrix rather than per point -- TODO confirm).
P_symm <- matrix(c(
  0.000000e+00, 0.0022956859, 0.0022079944, 0.0004763074, 4.338953e-02, 1.822079e-02, 0.002913239, 0.0413498285, 5.184416e-05, 0.004134502,
  2.295686e-03, 0.0000000000, 0.0188919615, 0.0129934442, 1.089032e-03, 5.689921e-04, 0.002131646, 0.0048261793, 6.996252e-03, 0.050676976,
  2.207994e-03, 0.0188919615, 0.0000000000, 0.0444964580, 2.464225e-03, 5.935835e-04, 0.040353636, 0.0027720360, 1.111298e-02, 0.013673490,
  4.763074e-04, 0.0129934442, 0.0444964580, 0.0000000000, 5.466455e-04, 2.771325e-04, 0.018028275, 0.0005761904, 3.389471e-02, 0.014363302,
  4.338953e-02, 0.0010890318, 0.0024642250, 0.0005466455, 0.000000e+00, 1.831834e-02, 0.006393040, 0.0329052015, 5.372241e-05, 0.002356628,
  1.822079e-02, 0.0005689921, 0.0005935835, 0.0002771325, 1.831834e-02, 0.000000e+00, 0.001326343, 0.0110122168, 1.065771e-05, 0.001168212,
  2.913239e-03, 0.0021316462, 0.0403536359, 0.0180282748, 6.393040e-03, 1.326343e-03, 0.000000000, 0.0059083283, 4.862680e-03, 0.002656313,
  4.134983e-02, 0.0048261793, 0.0027720360, 0.0005761904, 3.290520e-02, 1.101222e-02, 0.005908328, 0.0000000000, 2.982247e-04, 0.012212476,
  5.184416e-05, 0.0069962518, 0.0111129834, 0.0338947056, 5.372241e-05, 1.065771e-05, 0.004862680, 0.0002982247, 0.000000e+00, 0.004150755,
  4.134502e-03, 0.0506769759, 0.0136734904, 0.0143633019, 2.356628e-03, 1.168212e-03, 0.002656313, 0.0122124758, 4.150755e-03, 0.000000000
) * 10, nrow = 10, byrow = TRUE)


res <- perplexity_similarities(
  perplexity = 4, verbose = FALSE,
  nn = find_nn(iris10,
    k = 10, method = "fnn",
    metric = "euclidean", n_threads = 0,
    verbose = FALSE
  )
)
expect_true(Matrix::isSymmetric(res))
# `tolerance` is spelled out: expect_equal() declares it after `...`, so an
# abbreviated `tol` is never matched by expect_equal() itself.
expect_equal(as.matrix(res), P_symm,
  tolerance = 1e-5, check.attributes = FALSE
)

# Expected symmetrized affinities for iris10 at perplexity = 9 using all 10
# neighbors. With a perplexity this close to the number of other points the
# affinities are nearly uniform (about 1/9). The reference values are rounded
# to 4 decimal places, hence the looser tolerance below.
Psymm9 <- matrix(
  c(
    0, 0.1111, 0.1112, 0.1110, 0.1116, 0.1113, 0.1112, 0.1115, 0.1106, 0.1112,
    0.1111, 0, 0.1113, 0.1113, 0.1110, 0.1107, 0.1112, 0.1112, 0.1112, 0.1114,
    0.1112, 0.1113, 0, 0.1113, 0.1112, 0.1106, 0.1114, 0.1112, 0.1112, 0.1113,
    0.1110, 0.1113, 0.1113, 0, 0.1110, 0.1105, 0.1114, 0.1111, 0.1113, 0.1114,
    0.1116, 0.1110, 0.1112, 0.1110, 0, 0.1113, 0.1113, 0.1115, 0.1106, 0.1111,
    0.1113, 0.1107, 0.1106, 0.1105, 0.1113, 0, 0.1105, 0.1111, 0.1103, 0.1105,
    0.1112, 0.1112, 0.1114, 0.1114, 0.1113, 0.1105, 0, 0.1113, 0.1112, 0.1112,
    0.1115, 0.1112, 0.1112, 0.1111, 0.1115, 0.1111, 0.1113, 0, 0.1108, 0.1114,
    0.1106, 0.1112, 0.1112, 0.1113, 0.1106, 0.1103, 0.1112, 0.1108, 0, 0.1111,
    0.1112, 0.1114, 0.1113, 0.1114, 0.1111, 0.1105, 0.1112, 0.1114, 0.1111, 0
  ),
  byrow = TRUE, nrow = 10
)

res <- perplexity_similarities(
  perplexity = 9, verbose = FALSE,
  nn = find_nn(iris10,
    k = 10, method = "fnn",
    metric = "euclidean", n_threads = 0,
    verbose = FALSE
  )
)
expect_true(Matrix::isSymmetric(res))
# `tolerance` is spelled out rather than relying on partial matching of `tol`.
expect_equal(as.matrix(res), Psymm9,
  tolerance = 1e-4, check.attributes = FALSE
)


# Expected symmetrized affinities for iris10 at perplexity = 4 when only the
# 6 nearest neighbors are used. Zero off-diagonal entries are pairs that
# (presumably) appear in neither point's neighbor list. As above, the smallvis
# reference values are rescaled by 10.
P_symm_6nn <- matrix(c(
  0, 0, 0.004227396, 0, 0.038581602, 0.016370215, 0.003972948, 0.037491042, 0, 0.007253571,
  0, 0, 0.020541010, 0.01457322, 0, 0, 0, 0.008117719, 0.008608916, 0.043891243,
  0.004227396, 0.020541010, 0, 0.04314614, 0.004242199, 0, 0.036275982, 0.004791681, 0.010952319, 0.015666352,
  0, 0.014573224, 0.043146139, 0, 0, 0, 0.018725165, 0, 0.032811238, 0.015644628,
  0.038581602, 0, 0.004242199, 0, 0, 0.016370215, 0.010365583, 0.031963895, 0, 0.003730662,
  0.016370215, 0, 0, 0, 0.016370215, 0, 0.002795087, 0.011902114, 0, 0.002562369,
  0.003972948, 0, 0.036275982, 0.01872517, 0.010365583, 0.002795087, 0, 0.006321792, 0.004717900, 0.003609179,
  0.037491042, 0.008117719, 0.004791681, 0, 0.031963895, 0.011902114, 0.006321792, 0, 0, 0.015406444,
  0, 0.008608916, 0.010952319, 0.03281124, 0, 0, 0.004717900, 0, 0, 0.004370167,
  0.007253571, 0.043891243, 0.015666352, 0.01564463, 0.003730662, 0.002562369, 0.003609179, 0.015406444, 0.004370167, 0
) * 10, nrow = 10, byrow = TRUE)
res <- perplexity_similarities(
  perplexity = 4, verbose = FALSE,
  nn = find_nn(iris10,
    k = 6, method = "fnn",
    metric = "euclidean", n_threads = 0,
    verbose = FALSE
  )
)
expect_true(Matrix::isSymmetric(res))
# `tolerance` is spelled out rather than relying on partial matching of `tol`.
expect_equal(as.matrix(res), P_symm_6nn,
  tolerance = 1e-5, check.attributes = FALSE
)

# Expected row-wise (non-symmetrized) probabilities for iris10 at
# perplexity = 4 over its 10-nearest-neighbor graph; each row sums to ~1.
P_row <- matrix(c(
  0.000000e+00, 0.03254778, 0.04322171, 0.009522236, 4.179712e-01, 1.389888e-02, 0.03932256, 0.3802648571, 1.633620e-04, 0.06308741,
  1.336594e-02, 0.00000000, 0.21654628, 0.163906282, 4.387114e-03, 4.819686e-08, 0.02029701, 0.0618376045, 2.029701e-02, 0.49936271,
  9.381792e-04, 0.16129295, 0.00000000, 0.400023323, 9.381792e-04, 7.502536e-16, 0.29552576, 0.0143118438, 7.811162e-03, 0.11915861,
  3.912338e-06, 0.09596260, 0.48990584, 0.000000000, 3.912338e-06, 1.913538e-19, 0.09596260, 0.0009992458, 1.842071e-01, 0.13295484,
  4.498193e-01, 0.01739352, 0.04834632, 0.010928997, 0.000000e+00, 1.584988e-02, 0.07694327, 0.3403719404, 2.009270e-04, 0.04014584,
  3.505169e-01, 0.01137979, 0.01187167, 0.005542650, 3.505169e-01, 0.000000e+00, 0.02652673, 0.2200679860, 2.131340e-04, 0.02336422,
  1.894222e-02, 0.02233591, 0.51154696, 0.264602894, 5.091753e-02, 1.331050e-07, 0.00000000, 0.0834805843, 1.155348e-02, 0.03662029,
  4.467317e-01, 0.03468598, 0.04112888, 0.010524562, 3.177321e-01, 1.763500e-04, 0.03468598, 0.0000000000, 1.925167e-05, 0.11431520,
  8.735213e-04, 0.11962802, 0.21444851, 0.493687059, 8.735213e-04, 2.023140e-08, 0.08570012, 0.0059452421, 0.000000e+00, 0.07884399,
  1.960264e-02, 0.51417681, 0.15431120, 0.154311203, 6.986716e-03, 2.081749e-08, 0.01650597, 0.1299343209, 4.171117e-03, 0.00000000
), nrow = 10, byrow = TRUE)
# expected_sigmas <- c(0.3252233, 0.2679755, 0.1817380, 0.1751287, 0.3280264, 0.4861266, 0.2463306, 0.2422687, 0.3463065, 0.2411619)

# Run the same check with n_threads = 0 and n_threads = 1; `info` identifies
# which setting failed.
for (nt in c(0, 1)) {
  res <- calc_row_probabilities_parallel(iris10_nn10$dist, iris10_nn10$idx,
    perplexity = 4,
    n_threads = nt
  )$matrix
  # Scatter the per-neighbor values into a 10 x 10 matrix indexed by the
  # neighbor ids so it can be compared with P_row.
  res <- nn_to_sparse(iris10_nn10$idx, as.vector(res),
    self_nbr = TRUE, max_nbr_id = nrow(iris10_nn10$idx)
  )
  # `tolerance` is spelled out rather than relying on partial matching of `tol`.
  expect_equal(as.matrix(res), P_row,
    tolerance = 1e-5, check.attributes = FALSE,
    info = paste("n_threads =", nt)
  )
}

# Drop duplicated rows of iris (duplicates detected on its x2m() conversion),
# then apply LargeVis-style normalization: center each column, then divide
# by the largest absolute value so every entry lies in [-1, 1].
iris_dup <- duplicated(x2m(iris))
uiris <- iris[which(!iris_dup), ]
normiris <- local({
  centered <- scale(x2m(uiris), center = TRUE, scale = FALSE)
  centered / max(abs(centered))
})
# Disabled check, kept for reference: compare the betas (1 / sigma ^ 2)
# against those from the LargeVis C++ implementation. NOTE(review): `res`
# would have to come from a call that exposes `sigma` -- TODO confirm before
# re-enabling.
# Taken from LargeVis C++ implementation
# Prow_niris_p150_k50_betas <-
#   c(
#     5.885742, 5.736816, 5.197266, 5.471191, 5.71875, 5.699707, 5.451172, 6.242188, 4.727051, 5.95459,
#     5.53418, 6.278809, 5.412598, 3.991699, 4.12207, 4.150879, 4.842285, 6.005859, 5.578369, 5.635254,
#     6.838379, 5.978516, 4.450684, 7.612305, 7.321289, 6.594238, 6.863281, 6.144043, 6.030762, 6.04248,
#     6.217285, 6.470703, 4.745117, 4.282471, 6.104004, 5.433594, 5.414551, 5.560547, 4.568359, 6.307129,
#     5.703125, 4.453369, 4.664551, 7.005859, 6.844238, 5.671875, 5.775879, 5.260254, 5.626465, 5.994629,
#     13.903809, 19.361816, 17.199951, 10.595215, 21.495117, 19.960938, 22.048828, 6.962402, 17.828125, 9.40918,
#     6.223511, 18.996094, 11.31665, 26.000977, 9.661377, 15.227539, 18.856934, 12.867676, 20.546875, 9.992432,
#     22.505859, 16.000977, 23.277344, 24.697266, 18.488281, 16.779785, 17.730957, 23.207031, 25.374023, 8.538818,
#     8.724121, 8.112671, 12.005859, 23.245605, 16.113281, 19.423828, 17.885254, 19.319336, 13.402344, 11.15625,
#     15.037109, 24.227539, 12.86377, 6.838623, 14.530762, 14.649414, 15.500977, 20.317383, 7.972168, 14.302246,
#     8.755371, 19.519531, 9.589355, 18.270508, 11.995361, 4.407959, 11.406738, 6.748779, 14.204102, 6.195068,
#     22.625977, 21.842773, 14.548828, 17.580078, 15.902832, 15.451172, 19.606445, 3.676147, 3.610718, 18.935059,
#     10.456055, 18.219238, 4.096863, 26.099609, 12.159668, 8.970703, 27.116211, 26.008301, 15.51416, 11.36377,
#     7.287109, 4.036987, 14.57373, 25.604492, 17.268066, 5.501221, 11.135986, 19.822266, 25.111816, 14.611328,
#     11.272705, 15.021484, 9.565918, 9.592529, 15.895508, 22.691895, 22.258789, 13.867188, 22.583008
#   )
# niris10_nn149 <- dist_nn(dist(normiris), k = 149)
# expect_equal(1 / res$sigma ^ 2, Prow_niris_p150_k50_betas, tolerance = 1e-5)


# Expected row sums of the perplexity = 50 affinity matrix for the normalized,
# de-duplicated iris data (k = 149 neighbors), taken from the LargeVis C++
# implementation.
Prow_iris_p150_k50_rowSums <- c(
  1.064902, 1.01981, 1.022902, 1.00269, 1.058712, 0.959587, 1.020604, 1.072308, 0.918501, 1.035426,
  1.010711, 1.055485, 1.00664, 0.874596, 0.840662, 0.782034, 0.960034, 1.065464, 0.91116, 1.029154,
  1.016113, 1.041956, 0.94594, 1.038197, 1.010267, 1.021842, 1.064241, 1.060124, 1.058187, 1.030837,
  1.03162, 1.022887, 0.938471, 0.876479, 1.042369, 1.031878, 0.992018, 1.047312, 0.93035, 1.069906,
  1.057901, 0.766783, 0.954321, 1.030951, 0.977892, 1.013204, 1.025217, 1.012253, 1.028178, 1.065557,
  0.84282, 1.103245, 0.981218, 0.927567, 1.190567, 1.135817, 1.11851, 0.745579, 1.074413, 0.892972,
  0.717667, 1.111475, 0.84308, 1.285817, 0.884184, 0.949937, 1.078251, 1.009963, 0.930185, 0.96844,
  1.107436, 1.031162, 1.161345, 1.165258, 1.076716, 1.029794, 1.008188, 1.173165, 1.26895, 0.857554,
  0.919873, 0.887521, 1.015009, 1.200669, 0.958936, 0.979149, 1.070616, 0.946976, 1.019518, 0.987791,
  1.007625, 1.257148, 1.038777, 0.754076, 1.078727, 1.049221, 1.098134, 1.148382, 0.74274, 1.084166,
  0.819342, 1.072723, 0.958094, 1.101275, 1.036348, 0.744063, 0.689913, 0.811003, 0.896867, 0.781658,
  1.162044, 1.199109, 1.110908, 0.923403, 0.831603, 1.055974, 1.167772, 0.662894, 0.653821, 0.899437,
  1.003726, 0.952751, 0.705241, 1.281862, 1.039594, 0.882058, 1.306628, 1.290206, 1.09357, 0.864005,
  0.824926, 0.663579, 1.063836, 1.242015, 0.843297, 0.769286, 0.907494, 1.14366, 1.252945, 1.073047,
  1.022525, 0.951965, 0.977462, 0.941184, 1.050544, 1.128182, 1.230836, 0.925821, 1.158545
)

# Same check with n_threads = 0 and n_threads = 1, applied to both the
# neighbor search and the similarity calculation; `info` identifies which
# setting failed.
for (nt in c(0, 1)) {
  res <- perplexity_similarities(
    perplexity = 50, n_threads = nt, verbose = FALSE,
    nn = find_nn(normiris,
      k = 149, method = "fnn",
      metric = "euclidean", n_threads = nt,
      verbose = FALSE
    )
  )
  # `tolerance` is spelled out rather than relying on partial matching of `tol`.
  expect_equal(Matrix::rowSums(res), Prow_iris_p150_k50_rowSums,
    tolerance = 1e-6, info = paste("n_threads =", nt)
  )
}
